{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Computing coordinates on Triton GEMM example\n",
    "\n",
    "## Introduction\n",
    "\n",
    "The original tutorial is available here:\n",
    "https://triton-lang.org/master/getting-started/tutorials/03-matrix-multiplication.html#sphx-glr-getting-started-tutorials-03-matrix-multiplication-py\n",
    "\n",
    "The most important part of the tutorial is the distribution of the work across streaming multiprocessors (`SM`) of the GPU.  \n",
    "`Triton` has a concept of `program` which are basically the smallest unit of work.  \n",
    "\n",
    "The order in which these `program`s are executed has a direct impact on the performance of the computation.  \n",
    "\n",
    "In the matmul example, `triton` authors introduce the notion of groups of `program`s.  \n",
    "It's an optimized way to execute the `program`s (partially) in parallel where several `program`s will share the same memory access.  \n",
    "\n",
    "> Memory access is the most expensive operation in the kernel execution, often much more than the computation itself.  \n",
    "\n",
    "It's only \"partially\" in parallel because `GPU`s are not always big enough to process large matrices in a single step.  \n",
    "That is why we need to split the matrix into smaller pieces and process only some pieces in parallel.\n",
    "\n",
    "The part of the tutorial we are interested in:\n",
    "\n",
    "```python\n",
    "# program ID\n",
    "pid = tl.program_id(axis=0)\n",
    "# number of program ids along the M axis\n",
    "num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n",
    "# number of programs ids along the N axis\n",
    "num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n",
    "# number of programs in group\n",
    "num_pid_in_group = GROUP_SIZE_M * num_pid_n\n",
    "# id of the group this program is in\n",
    "group_id = pid // num_pid_in_group\n",
    "# row id of the first program in the group\n",
    "first_pid_m = group_id * GROUP_SIZE_M\n",
    "# if `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller\n",
    "group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n",
    "# *within groups*, programs are ordered in a column-major order\n",
    "# row id of the program in the *launch grid*\n",
    "pid_m = first_pid_m + (pid % group_size_m)\n",
    "# col-id of the program in the *launch grid*\n",
    "pid_n = (pid % num_pid_in_group) // group_size_m\n",
    "```\n",
    "\n",
    "## Initialization\n",
    "\n",
    "We will try to redo the computation in `torch` to see what `triton` code above means.\n",
    "\n",
    "> Reminder: GEMM problem are usually presented as `MNK` problem where matrix `A` is `MxK` (input), matrix `B` is `KxN` \n",
    "> (input) and matrix `C` is `MxN` (output)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "C [1024x768] = A [1024x128] * B [128x768]\n",
      "num_pid_m: 8\n",
      "num_pid_n: 12\n",
      "num_pid_n * GROUP_SIZE_M = 24\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def cdiv(x, y):\n",
    "    \"\"\"\n",
    "    Ceiling division returns the closest integer greater than or equal to the quotient.\n",
    "    \"\"\"\n",
    "    return (x + y - 1) // y\n",
    "\n",
    "\n",
    "M = 1024\n",
    "N = 768\n",
    "K = 128\n",
    "\n",
    "print(f\"C [{M}x{N}] = A [{M}x{K}] * B [{K}x{N}]\")\n",
    "\n",
    "a = torch.rand(M, K)\n",
    "b = torch.rand(K, N)\n",
    "c = torch.rand(M, N)\n",
    "\n",
    "# programs work at the tile level, below are their dimensions for each axis\n",
    "BLOCK_SIZE_M = 128\n",
    "BLOCK_SIZE_N = 64\n",
    "BLOCK_SIZE_K = 32\n",
    "\n",
    "# group is a special concept to speed up matmul in this tutorial\n",
    "GROUP_SIZE_M = 2\n",
    "\n",
    "# # of programs to run on each C axis (each program will iterate over K)\n",
    "num_pid_m = cdiv(M, BLOCK_SIZE_M)  # number of `program`s in M dimension, rounded to the nearest bigger integer\n",
    "num_pid_n = cdiv(N, BLOCK_SIZE_N)  # number of `program`s in N dimension, rounded to the nearest bigger integer\n",
    "\n",
    "print(\"num_pid_m:\", num_pid_m)\n",
    "print(\"num_pid_n:\", num_pid_n)\n",
    "print(\"num_pid_n * GROUP_SIZE_M =\", num_pid_n * GROUP_SIZE_M)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The main trick to limit global memory (`GM` aka the DDRAM) accesses is to have parallel programs that require the same part of the matrix.\n",
    "\n",
    "Each `program` has a unique id, which is the index of the program in the list of `program`s.  \n",
    "As each `program` will  iterate over the `K` axis, the formula below (from the tutorial) will provide us the total number of `program`s to launch:\n",
    "\n",
    "```python\n",
    "grid = lambda META: (\n",
    "    triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),\n",
    ")\n",
    "```\n",
    "\n",
    "Below, we redo the computation in `torch`with the variables defined above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nb programs to launch: 96\n"
     ]
    }
   ],
   "source": [
    "nb_programs = num_pid_m * num_pid_n  # number of programs to launch\n",
    "print(\"nb programs to launch:\", nb_programs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# some example program ID\n",
    "pid = 60\n",
    "assert (\n",
    "    pid < nb_programs\n",
    "), f\"we will launch a {num_pid_m}x{num_pid_n}={nb_programs} grid of programs, pid={pid} is too big\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We need at least `num_pid_m` `program`s to cover the `M` axis (number of rows of matrix `A`) and `num_pid_n` `program`s\n",
    "to cover the `N` axis (number of columns of matrix `B`).  \n",
    "Each of those `program`s will have to iterate over the `K` axis of `A` and `B` matrices.  \n",
    "\n",
    "The goal of our strategy is to reuse data as much as possible.  \n",
    "For that, we will run in parallel several `program`s that need the same data from one of the input matrix.  \n",
    "\n",
    "Each `program` consumes `A` and `B` matrices.  \n",
    "We can reuse either data from `A` or `B`, we arbitrarily choose to reuse data from `B`.  \n",
    "It means that we run several `program`s on different rows of `A` (axis `M`) matrix that consume the same column from `B` (axis `N`).  \n",
    "\n",
    "> reminder: each `program` is responsible to iterate over `K` axis\n",
    "\n",
    "If we express that logic in pseudo-code, it would look like a nested loop where the outer loop (axis `N` of `B` matrix) is mostly serially iterated  \n",
    " and the inner loop (axis `M` of `A` matrix) is parallelized.\n",
    "\n",
    "<!-- TODO range are not true -->\n",
    "```python\n",
    "for pos_n in range(num_pid_n):  # serialized iteration\n",
    "    for pos_m in range(num_pid_m):  # GROUP_SIZE_M programs in parallel\n",
    "        # each program is associated with a position on the M and N axis and will iterate itself over the K axis\n",
    "        do_work()\n",
    "```\n",
    "\n",
    "Each group of `GROUP_SIZE_M` `program`s (which are positioned on `GROUP_SIZE_M` different rows on the `M` axis of `A`) \n",
    "will consume the same complete column of the `N` axis of `B` matrix before switching to the next group.  \n",
    "We need `GROUP_SIZE_M * num_pid_n` `program`s to finish the computation of a single `C` matrix block."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "24\n"
     ]
    }
   ],
   "source": [
    "num_pid_in_group = GROUP_SIZE_M * num_pid_n\n",
    "assert num_pid_n * GROUP_SIZE_M <= M\n",
    "print(num_pid_in_group)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we need `num_pid_in_group` `program`s to process a single `C` block, we can guess the group id of the current program:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n"
     ]
    }
   ],
   "source": [
    "group_id = pid // num_pid_in_group\n",
    "print(group_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we will:\n",
    "* compute the `pid` of the first program in our group;\n",
    "* compute the real size of the group, we want to catch the case of the last group of the row when its dimension is inferior to the others.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "first_pid_m: 4\n",
      "group_size_m: 2\n"
     ]
    }
   ],
   "source": [
    "# row-id of the first program in the group\n",
    "first_pid_m = group_id * GROUP_SIZE_M\n",
    "\n",
    "print(\"first_pid_m:\", first_pid_m)\n",
    "# if `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller\n",
    "group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n",
    "print(\"group_size_m:\", group_size_m)\n",
    "\n",
    "assert group_size_m > 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We know which group we are part of, so we just need to find our spot in this group.  \n",
 As noted">
    "> As noted in the original comments of the tutorial, *\"within groups, programs are ordered in a column-major order\"*. \n",
    "> `pid_m` and `pid_n` are the row id and the column id of the program in the launch grid."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pid_m: 4\n",
      "pid_n: 6\n"
     ]
    }
   ],
   "source": [
    "pid_m = first_pid_m + (pid % group_size_m)\n",
    "print(\"pid_m:\", pid_m)\n",
    "\n",
    "pid_n = (pid % num_pid_in_group) // group_size_m\n",
    "print(\"pid_n:\", pid_n)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To conclude this explanation, we will see how `pid_m` and `pid_n` are used to generate memroy offsets to read / write.\n",
    "\n",
    "> Code below doesn't make sense in the `triton` context and are just here for the sake of completeness.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "a_ptrs\n",
      " tensor([[75261632, 75261633, 75261634,  ..., 75261661, 75261662, 75261663],\n",
      "        [75261760, 75261761, 75261762,  ..., 75261789, 75261790, 75261791],\n",
      "        [75261888, 75261889, 75261890,  ..., 75261917, 75261918, 75261919],\n",
      "        ...,\n",
      "        [75277632, 75277633, 75277634,  ..., 75277661, 75277662, 75277663],\n",
      "        [75277760, 75277761, 75277762,  ..., 75277789, 75277790, 75277791],\n",
      "        [75277888, 75277889, 75277890,  ..., 75277917, 75277918, 75277919]])\n",
      "b_ptrs\n",
      " tensor([[75722624, 75722625, 75722626,  ..., 75722685, 75722686, 75722687],\n",
      "        [75723392, 75723393, 75723394,  ..., 75723453, 75723454, 75723455],\n",
      "        [75724160, 75724161, 75724162,  ..., 75724221, 75724222, 75724223],\n",
      "        ...,\n",
      "        [75744896, 75744897, 75744898,  ..., 75744957, 75744958, 75744959],\n",
      "        [75745664, 75745665, 75745666,  ..., 75745725, 75745726, 75745727],\n",
      "        [75746432, 75746433, 75746434,  ..., 75746493, 75746494, 75746495]])\n"
     ]
    }
   ],
   "source": [
    "# torch semantic to retrieve data array pointer\n",
    "a_ptr = a.storage().data_ptr()\n",
    "b_ptr = b.storage().data_ptr()\n",
    "c_ptr = c.storage().data_ptr()\n",
    "\n",
    "# stride is the memory offset between two consecutive rows / columns\n",
    "stride_am, stride_ak = a.stride()\n",
    "stride_bk, stride_bn = b.stride()\n",
    "stride_cm, stride_cn = c.stride()\n",
    "\n",
    "# below we perform the conversion from starting pointer to a matrix of pointers\n",
    "offs_am = pid_m * BLOCK_SIZE_M + torch.arange(0, BLOCK_SIZE_M)\n",
    "offs_bn = pid_n * BLOCK_SIZE_N + torch.arange(0, BLOCK_SIZE_N)\n",
    "offs_k = torch.arange(0, BLOCK_SIZE_K)\n",
    "\n",
    "# broadcasting is leveraged to compute the offsets\n",
    "a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n",
    "b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n",
    "\n",
    "print(\"a_ptrs\\n\", a_ptrs)\n",
    "print(\"b_ptrs\\n\", b_ptrs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.13 ('venv': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "9304613f370859586d5dde245ba17471fe0e4bae74a871b63d1672bdb9f882ed"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
